"""Main entry point for doing all pruning-related stuff. Adapted from https://github.com/arunmallya/packnet/blob/master/src/main.py"""
from __future__ import division, print_function
import gc
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
import numpy as np
import time
import art.attacks.evasion
import mi_estimator
import warnings
# To prevent PIL warnings.
warnings.filterwarnings("ignore")
from torchmetrics import Accuracy
from torchvision import models
import cifar10models
import cifar100models
import data
from torch.autograd import Variable
from tqdm import tqdm
from art.estimators.classification import PyTorchClassifier
from torchsummary import summary
import utils
from manager_lincom import Manager

### General flags
# Command-line interface of the experiment script; parsed once at the top of main().
FLAGS = argparse.ArgumentParser()

# --network, --attacktype and --dataset are concatenated into data/checkpoint
# paths and drive model selection in main(), so a missing value can never work.
# They are marked required so that argparse reports a clear usage error instead
# of the script failing later on a None value.
FLAGS.add_argument('--network', choices=['AlexNet', 'VGG16', 'Resnet50'], required=True, help='Architectures')
FLAGS.add_argument('--attacktype', choices=['FGSM', 'C&W', 'PGD', 'I-FGSM', 'DeepFool'], required=True, help='Type of adversarial attack used')
FLAGS.add_argument('--num_fb_layers', type=int, default=4, help='Number of layers allocated to subnetwork fb')
FLAGS.add_argument('--dataset', choices=['CIFAR10', 'CIFAR100', 'MNIST', 'Imagenette2'], required=True, help='Dataset used for training')
FLAGS.add_argument('--batchsize', type=int, default=512, help='Batch size')
FLAGS.add_argument('--epochs', type=int, default=20, help='Number of training epochs')
FLAGS.add_argument('--learning_rate', type=float, default=1e-2, help='Learning rate')
FLAGS.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
# NOTE: store_true combined with default=True means args.cuda is True whether or
# not the flag is passed; there is currently no way to switch it off from the CLI.
FLAGS.add_argument('--cuda', action='store_true', default=True, help='use CUDA')
# Parsed as the strings "True"/"False" (not booleans); None when omitted.
FLAGS.add_argument('--avg', choices=["False","True"], help='average lambdas over batches to reduce runtime')

   
######################################################################################################################################################################
###
###     Main function
###
######################################################################################################################################################################



# Layer-index tables keyed by --network. main() treats them as the ordered list
# of "weight layers" of each architecture.
# NOTE(review): the original author left an open question on these tables:
# "How are these weight layers picked? For VGG -1, -4 and -7 match the tunable
# layers in the FC portion, but -15 and -19 don't seem to match the immediately
# prior weight layers." -- still to be confirmed.
_VGG16_LAYERS = [2,5,9,12,16,19,22,26,29,32,36,39,42,48,51,54]
_RESNET50_LAYERS = [1,7,9,11,15,18,20,22,26,28,30,35,37,39,43,46,48,50,54,56,58,62,64,66,71,73,75,79,82,84,86,90,92,94,98,100,102,106,108,110,114,116,118,123,125,127,131,134,136,138,142,144,146,150]

_LAYERDICT_CIFAR = {
    "AlexNet": [2,5,8,10,12,15],
    "VGG16": _VGG16_LAYERS,
    "Resnet50": _RESNET50_LAYERS,
}

_LAYERDICT_IMAGENETTE2 = {
    "AlexNet": [2,5,8,10,12,18,21,23],
    "VGG16": _VGG16_LAYERS,
    "Resnet50": _RESNET50_LAYERS,
}


def _load_pretrained_model(args):
    """Build the network selected by ``args.dataset``/``args.network`` with pretrained weights.

    Raises:
        ValueError: if no model is defined for the requested dataset/network
            pair (e.g. ``MNIST``, which the CLI accepts but has no branch here).
    """
    if args.dataset == "Imagenette2":
        torchvision_builders = {
            'AlexNet': models.AlexNet,
            'VGG16': models.vgg16_bn,
            'Resnet50': models.resnet50,
        }
        if args.network in torchvision_builders:
            model = torchvision_builders[args.network]()
            # The checkpoint directory is named after the network.
            state_dict = torch.load('../state_dicts/Imagenette2/' + args.network + '/pretrain_dict.zip')
            model.load_state_dict(state_dict)
            return model
    elif args.dataset == "CIFAR10":
        if args.network == 'AlexNet':
            model = cifar10models.AlexNet(args)
            print("model generated")

            state_dict = torch.load('../state_dicts/' + args.dataset + '/model_best.pth.tar')['state_dict']
            # Rename state_dict's keys to match those of model
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict)
            return model
        if args.network == 'VGG16':
            return cifar10models.vgg16_bn(args, pretrained=True)
        if args.network == 'Resnet50':
            return cifar10models.resnet50(args, pretrained=True)
    elif args.dataset == "CIFAR100":
        if args.network == 'AlexNet':
            model = cifar100models.AlexNet(args, num_classes=100)
            state_dict = torch.load('../state_dicts/CIFAR100/AlexNet/pretrain_dict.zip')
            model.load_state_dict(state_dict)
            return model
        if args.network == 'VGG16':
            return cifar100models.vgg16_bn(args, pretrained=True)
        if args.network == 'Resnet50':
            return cifar100models.resnet50(args, pretrained=True)

    raise ValueError(
        f"No pretrained model available for dataset={args.dataset!r}, network={args.network!r}"
    )


def _select_weight_layers(dataset, network):
    """Return the weight-layer index list for ``network`` on ``dataset``.

    Raises:
        ValueError: if ``dataset`` has no layer table.
    """
    if dataset == "Imagenette2":
        layerdict = _LAYERDICT_IMAGENETTE2
    elif dataset in ("CIFAR10", "CIFAR100"):
        layerdict = _LAYERDICT_CIFAR
    else:
        raise ValueError(f"No weight-layer table defined for dataset {dataset!r}")
    return layerdict[network]


def _pair_samples(inputs, labels):
    """Return ``[[inputs[i], labels[i]], ...]`` for every row of ``inputs``.

    Raises:
        ValueError: if there are fewer labels than inputs (zip would otherwise
            silently drop the unlabeled samples).
    """
    if len(labels) < len(inputs):
        raise ValueError(
            f"Got {len(inputs)} adversarial samples but only {len(labels)} labels"
        )
    return [[sample, label] for sample, label in zip(inputs, labels)]


def _load_adversarial_data(advdatapath):
    """Load the precomputed adversarial train/test sets stored under ``advdatapath``.

    Reads ``x_train_adv.npy``, ``x_test_adv.npy``, ``y_train.npy`` and
    ``y_test.npy``; the inputs are converted to float tensors, the labels are
    kept as loaded.

    Returns:
        ``(traindata_adv, testdata_adv)``, each a list of ``[input, label]`` pairs.
    """
    x_train_adv = np.load((advdatapath + "/x_train_adv.npy"))
    print('x_train_adv loaded.')
    x_test_adv = np.load((advdatapath + "/x_test_adv.npy"))
    print('x_test_adv loaded.')
    y_train = np.load((advdatapath + "/y_train.npy"))
    y_test = np.load((advdatapath + "/y_test.npy"))

    x_train_adv = torch.from_numpy(x_train_adv).float()
    x_test_adv = torch.from_numpy(x_test_adv).float()

    print("Type of x_train_adv is: ", type(x_train_adv))

    return _pair_samples(x_train_adv, y_train), _pair_samples(x_test_adv, y_test)


def main():
    """Run the full experiment: adversarial finetuning followed by the lincom optimization.

    Steps, in order:
      1. Parse the CLI flags, load the clean dataset and the pretrained model.
      2. Load the precomputed adversarial examples for the chosen attack.
      3. Evaluate the pretrained model, finetune it on adversarial data and
         reload the checkpoint written during training.
      4. Record the baseline, freeze f_a, and optimize the lincom lambdas.
      5. Save lambdas, the lincom loss and a summary array of accuracies.

    Raises:
        ValueError: for an unsupported dataset/network pair or an out-of-range
            ``--num_fb_layers``.
    """
    args = FLAGS.parse_args()
    # NOTE: the GPU is used unconditionally (here and in model.cuda() below);
    # args.cuda is not consulted in this function.
    torch.cuda.set_device(0)
    print("start")

    #####################################################################################
    ###    Prepare Data and Loaders
    #####################################################################################
    datavar = data.Dataset(("../data/" + args.dataset), args.dataset)
    print("loaded dataset")

    #####################################################################################
    ###    Prepare The Model
    #####################################################################################
    model = _load_pretrained_model(args)
    print("model loaded")

    weight_layers = _select_weight_layers(args.dataset, args.network)
    num_layers = len(weight_layers)
    # Without this check a value larger than num_layers yields a negative index
    # that silently wraps around to a layer near the end of the list.
    if not 1 <= args.num_fb_layers <= num_layers:
        raise ValueError(
            f"--num_fb_layers must be between 1 and {num_layers} for "
            f"{args.network} on {args.dataset}, got {args.num_fb_layers}"
        )

    ### f_b_start is the layer index at which the last num_fb_layers weight layers
    ### begin (per the original notes, f_b is the first layer of the
    ### non-adversarial subnetwork).
    f_b_index = num_layers - args.num_fb_layers
    f_b_start = weight_layers[f_b_index]
    print(f'Set-Up: {f_b_index}, {f_b_start}, {args.attacktype}, {args.network}')

    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9, nesterov=True)

    #####################################################################################
    ###    Prepare Adversarial Data
    #####################################################################################
    advdatapath = "../attacks/" + args.network + "/" + args.attacktype +"/" + args.dataset
    model.cuda()

    traindata_adv, testdata_adv = _load_adversarial_data(advdatapath)

    trainloader = datavar.train_dataloader()
    testloader = datavar.test_dataloader()

    advtrainloader = torch.utils.data.DataLoader(traindata_adv, batch_size=args.batchsize, num_workers=1, shuffle=True, drop_last=True, pin_memory=True)
    advtestloader = torch.utils.data.DataLoader(testdata_adv, batch_size=args.batchsize, num_workers=1, shuffle=False, drop_last=True, pin_memory=True)

    #####################################################################################
    ###    Train on Adversarial Data
    #####################################################################################

    ### manager with pretrained normal model
    manager = Manager(args, model, trainloader, testloader, advtrainloader, advtestloader)

    # Evaluate the pretrained model on the adversarial test set.
    # manager.eval()[0] is subtracted from 100 -- presumably it is an error
    # percentage; TODO confirm against Manager.eval.
    pretrained_adv_acc = manager.eval(adversarial=True, data="Testing")
    pretrained_adv_acc = 100 - pretrained_adv_acc[0]
    print("pretrained adversarial acc:", pretrained_adv_acc)

    # Evaluate the pretrained model on the clean test set.
    pretrained_acc = manager.eval(adversarial=False, data="Testing")
    pretrained_acc = 100 - pretrained_acc[0]
    print("pretrained acc:", pretrained_acc)

    print("\n\n\n")

    root_save_path = './saves/' + args.network + "/" + args.attacktype + "/" + args.dataset +"/"
    normal_path = root_save_path + 'normal_' + str(f_b_start)
    os.makedirs(root_save_path,exist_ok=True)
    manager.save_model(normal_path)

    # Finetune on adversarial data and report the wall-clock training time.
    trt = time.time()
    manager.train(args.epochs, optimizer, adversarial=True)
    trt = time.time() - trt
    print("adversarial training time (s):", trt)

    # NOTE(review): this path is ordered dataset/network/attacktype, whereas
    # root_save_path above is ordered network/attacktype/dataset. The checkpoint
    # is presumably written here by Manager.train -- confirm before unifying.
    checkpoint_path = ("./saves/" + args.dataset + "/" + args.network + "/" + args.attacktype + "/checkpoint")
    checkpoint = torch.load(checkpoint_path)
    manager.load_model(checkpoint)

    adv_path = root_save_path + 'adversarial_' + str(f_b_start)
    manager.save_model(adv_path)

    baseline_acc = manager.eval(adversarial=True, data="Testing")
    baseline_acc = 100 - baseline_acc[0]
    print("baseline acc:", baseline_acc)

    ### Get baseline activations of f_b*
    manager.evalLincom(adversarial=True, store=True, lincom=False, data="Testing")

    #####################################################################################
    ###    Run Experiment
    #####################################################################################
    ### As is this is loading the network with f_a* and f_b, however f_b is only pretrained
    # NOTE(review): the meaning of the constant 99 is defined by Manager.freeze_fa -- confirm.
    manager.freeze_fa(99)
    lincom_initial_acc = manager.eval(adversarial=True, data="Testing", lincom=True)
    lincom_initial_acc = 100 - lincom_initial_acc[0]

    lincom_optimal_acc = manager.train_lincom(save=False, target_accuracy=baseline_acc, adversarial=True, eps=1e-16)

    print("baseline acc:", baseline_acc)
    print("initial lincom acc:", lincom_initial_acc)
    print("optimized lincom acc:", lincom_optimal_acc)
    print("Difference from Baseline Acc: ", baseline_acc - lincom_optimal_acc)

    #####################################################################################
    ###    Save Experiment Results
    #####################################################################################
    base_path = ("./saves/" + args.dataset + "/" + args.network + "/" + args.attacktype)
    os.makedirs(base_path, exist_ok=True)

    loss_path = (base_path + "/lincom_loss")
    lambda_path = (base_path + "/lambdas")
    os.makedirs(loss_path, exist_ok=True)
    os.makedirs(lambda_path, exist_ok=True)
    np.save((lambda_path + "/opt_lambdas.npy"),manager.model.lincom1.lambdas)
    np.save((loss_path + "/lincom_loss.npy"),manager.loss)

    results_path = ("./results/" + args.dataset + "/" + args.network + "/" + args.attacktype)
    os.makedirs(results_path, exist_ok=True)
    # Layout: [baseline, initial lincom, optimized lincom, baseline - optimized].
    results = np.asarray([baseline_acc, lincom_initial_acc, lincom_optimal_acc, (baseline_acc-lincom_optimal_acc)])
    # --avg is a string flag; anything other than "True" (including omitted) uses the default file name.
    if args.avg=="True":
        np.save((results_path + "/results_avg_solution.npy"), results)
    else:
        np.save((results_path + "/results.npy"), results)

# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
